{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用决策树预测隐形眼镜类型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**说明:**\n",
    "\n",
    "将数据集文件 'lenses.txt' 放在当前文件夹"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from math import log\n",
    "import operator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 熵的定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def calcShannonEnt(dataSet):\n",
    "    numEntries = len(dataSet)\n",
    "    labelCounts = {}\n",
    "    for featVec in dataSet: #the the number of unique elements and their occurance\n",
    "        currentLabel = featVec[-1]\n",
    "        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0\n",
    "        labelCounts[currentLabel] += 1\n",
    "    shannonEnt = 0.0\n",
    "    for key in labelCounts:\n",
    "        prob = float(labelCounts[key])/numEntries\n",
    "        shannonEnt -= prob * log(prob, 2) #log base 2\n",
    "    return shannonEnt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 划分数据集: 按照给定特征划分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def splitDataSet(dataSet, axis, value):\n",
    "    retDataSet = []\n",
    "    for featVec in dataSet:\n",
    "        if featVec[axis] == value:\n",
    "            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting\n",
    "            reducedFeatVec.extend(featVec[axis+1:])\n",
    "            retDataSet.append(reducedFeatVec)\n",
    "    return retDataSet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 选择最好的数据集划分方式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def chooseBestFeatureToSplit(dataSet):\n",
    "    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels\n",
    "    baseEntropy = calcShannonEnt(dataSet)\n",
    "    bestInfoGain = 0.0; bestFeature = -1\n",
    "    for i in range(numFeatures):        #iterate over all the features\n",
    "        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature\n",
    "        uniqueVals = set(featList)       #get a set of unique values\n",
    "        newEntropy = 0.0\n",
    "        for value in uniqueVals:\n",
    "            subDataSet = splitDataSet(dataSet, i, value)\n",
    "            prob = len(subDataSet)/float(len(dataSet))\n",
    "            newEntropy += prob * calcShannonEnt(subDataSet)\n",
    "        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy\n",
    "        if (infoGain > bestInfoGain):       #compare this to the best gain so far\n",
    "            bestInfoGain = infoGain         #if better than current best, set to best\n",
    "            bestFeature = i\n",
    "    return bestFeature    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 多数表决法决定该叶子节点的分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def majorityCnt(classList):\n",
    "    classCount={}\n",
    "    for vote in classList:\n",
    "        if vote not in classCount.keys(): classCount[vote] = 0\n",
    "        classCount[vote] += 1\n",
    "    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)\n",
    "    return sortedClassCount[0][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 递归构建决策树"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def createTree(dataSet, labels):\n",
    "    classList = [example[-1] for example in dataSet]\n",
    "    if classList.count(classList[0]) == len(classList):\n",
    "        return classList[0]#stop splitting when all of the classes are equal\n",
    "    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet\n",
    "        return majorityCnt(classList)\n",
    "    bestFeat = chooseBestFeatureToSplit(dataSet)\n",
    "    bestFeatLabel = labels[bestFeat]\n",
    "    myTree = {bestFeatLabel:{}}\n",
    "    del(labels[bestFeat])\n",
    "    featValues = [example[bestFeat] for example in dataSet]\n",
    "    uniqueVals = set(featValues)\n",
    "    for value in uniqueVals:\n",
    "        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels\n",
    "        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)\n",
    "    return myTree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用决策树执行分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def classify(inputTree, featLabels, testVec):\n",
    "    firstStr = list(inputTree)[0]\n",
    "    secondDict = inputTree[firstStr]\n",
    "    featIndex = featLabels.index(firstStr)\n",
    "    key = testVec[featIndex]\n",
    "    valueOfFeat = secondDict[key]\n",
    "    if isinstance(valueOfFeat, dict):\n",
    "        classLabel = classify(valueOfFeat, featLabels, testVec)\n",
    "    else: classLabel = valueOfFeat\n",
    "    return classLabel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用决策树预测隐形眼镜类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'no': {'age': {'pre': 'soft', 'young': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}}}, 'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'young': 'hard', 'presbyopic': 'no lenses'}}, 'myope': 'hard'}}}}}}\n"
     ]
    }
   ],
   "source": [
    "fr = open('lenses.txt')\n",
    "lenses = [inst.strip().split('\\t') for inst in fr.readlines()]\n",
    "lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']\n",
    "lensesTree = createTree(lenses,lensesLabels)\n",
    "print(lensesTree)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用决策树模型进行预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'no lenses'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']\n",
    "classify(lensesTree, lensesLabels, lenses[0][:-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对 lenses 数据集所有数据进行决策树分类预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['no lenses', 'soft', 'no lenses', 'hard', 'no lenses', 'soft', 'no lenses', 'hard', 'no lenses', 'soft', 'no lenses', 'hard', 'no lenses', 'soft', 'no lenses', 'no lenses', 'no lenses', 'no lenses', 'no lenses', 'hard', 'no lenses', 'soft', 'no lenses', 'no lenses']\n"
     ]
    }
   ],
   "source": [
    "preds = []\n",
    "for i in range(len(lenses)):\n",
    "    pred = classify(lensesTree, lensesLabels, lenses[i][:-1])\n",
    "    preds.append(pred)\n",
    "print(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
